import datetime
from dateutil.relativedelta import relativedelta
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.optimize import least_squares
from scipy.signal import find_peaks
from sklearn.metrics import mean_squared_error
# Load weekly influenza case counts (1997-2021) from the raw spreadsheet.
flu_data = pd.read_excel('flu_cases_1997_2021_raw_data.xlsx')
# Inspect the first rows to confirm the column layout (YEAR, WEEK, Total_cases).
flu_data.head()
| YEAR | WEEK | Total_cases | |
|---|---|---|---|
| 0 | 1997 | 40 | 0 |
| 1 | 1997 | 41 | 11 |
| 2 | 1997 | 42 | 17 |
| 3 | 1997 | 43 | 8 |
| 4 | 1997 | 44 | 10 |
We introduce two additional columns, which we think may be useful - DATE and DAYS (for counting days since the beginning of data acquisition).
# DATE = January 1st of YEAR shifted forward by WEEK * 7 days.
week_offset = pd.to_timedelta(flu_data['WEEK'] * 7, unit='D')
flu_data['DATE'] = pd.to_datetime(flu_data['YEAR'].astype(str), format='%Y') + week_offset
flu_data.head()
| YEAR | WEEK | Total_cases | DATE | |
|---|---|---|---|---|
| 0 | 1997 | 40 | 0 | 1997-10-08 |
| 1 | 1997 | 41 | 11 | 1997-10-15 |
| 2 | 1997 | 42 | 17 | 1997-10-22 |
| 3 | 1997 | 43 | 8 | 1997-10-29 |
| 4 | 1997 | 44 | 10 | 1997-11-05 |
# DAYS counts days elapsed since the first observation (one row per week).
flu_data['DAYS'] = [7 * i for i in range(len(flu_data))]
flu_data.head()
| YEAR | WEEK | Total_cases | DATE | DAYS | |
|---|---|---|---|---|---|
| 0 | 1997 | 40 | 0 | 1997-10-08 | 0 |
| 1 | 1997 | 41 | 11 | 1997-10-15 | 7 |
| 2 | 1997 | 42 | 17 | 1997-10-22 | 14 |
| 3 | 1997 | 43 | 8 | 1997-10-29 | 21 |
| 4 | 1997 | 44 | 10 | 1997-11-05 | 28 |
Additionally, we rename Total_cases column to maintain uniformity of nomenclature.
# Upper-case Total_cases so all columns share the YEAR/WEEK/DATE/DAYS naming style.
flu_data = flu_data.rename(columns={'Total_cases': 'TOTAL_CASES'})
flu_data.head()
| YEAR | WEEK | TOTAL_CASES | DATE | DAYS | |
|---|---|---|---|---|---|
| 0 | 1997 | 40 | 0 | 1997-10-08 | 0 |
| 1 | 1997 | 41 | 11 | 1997-10-15 | 7 |
| 2 | 1997 | 42 | 17 | 1997-10-22 | 14 |
| 3 | 1997 | 43 | 8 | 1997-10-29 | 21 |
| 4 | 1997 | 44 | 10 | 1997-11-05 | 28 |
# Interactive time series of weekly case counts over the whole 1997-2021 period.
fig = px.line(flu_data, x='DATE', y="TOTAL_CASES", hover_data=['DATE'])
fig.show()
We can see that the data since April 2020 differ from those of previous years — probably because of the COVID-19 pandemic. We've decided to remove that data from our dataset.
# Drop the COVID-affected seasons before modelling.
covid_years = flu_data['YEAR'].isin([2020, 2021])
flu_data = flu_data[~covid_years]
# Re-plot to check the cleaned series.
fig = px.line(flu_data, x='DATE', y="TOTAL_CASES", hover_data=['DATE'])
fig.show()
# Locate one seasonal maximum per flu season: distance=24 weeks keeps peaks
# at least roughly half a year apart.
peaks, _ = find_peaks(flu_data['TOTAL_CASES'], distance=24)
fig = px.line(flu_data, x='DATE', y="TOTAL_CASES", hover_data=['DATE'])
fig.add_scatter(
    x=flu_data['DATE'].iloc[peaks],
    # .iloc, not []: find_peaks returns POSITIONAL indices. Plain [] indexing
    # is label-based (and Series[list-of-ints] is deprecated); it only worked
    # here because the filtered index happened to stay contiguous.
    y=flu_data['TOTAL_CASES'].iloc[peaks],
    mode="markers",
    marker_symbol='x',
    marker=dict(size=8),
    name='local maxima'
)
fig.show()
In the article mentioned at the beginning, the author tried to approximate the data with a linear function. We suggest that a 1st-degree polynomial is not sufficient for this approximation, so we investigated polynomials of higher degrees.
# choosing dataframe only for peaks
# (find_peaks returned positional indices, so .iloc is the correct accessor)
peaks_flu = flu_data.iloc[peaks]
def poly_coefficients(x, coeffs):
    """Evaluate a polynomial at ``x``.

    Parameters
    ----------
    x : scalar or array-like abscissa value(s)
    coeffs : polynomial coefficients in ``np.polyfit`` order,
        i.e. highest degree first.

    Returns
    -------
    The polynomial value(s) at ``x`` (same shape as ``x``).

    Notes
    -----
    The previous implementation had two bugs: it summed
    ``coeffs[i] * x**i`` (lowest-degree-first), which does not match the
    highest-degree-first ordering of the ``np.polyfit`` results passed in
    by every caller, and it additionally added the constant term twice
    (once inside the loop and once via ``y += coeffs[-1]``).
    """
    return np.polyval(coeffs, x)
x = flu_data['DAYS']
# Linear fit through the seasonal peaks, then extrapolate the trend
# one year past the last observed peak.
p_1 = np.polyfit(x=peaks_flu['DAYS'], y=peaks_flu['TOTAL_CASES'], deg=1)
np.polyval(p_1, peaks_flu['DAYS'].iloc[-1] + 365)
3767.5588269873924
# Fit polynomials of degree 2 through 6 to the peak values.
p_2, p_3, p_4, p_5, p_6 = (
    np.polyfit(x=peaks_flu['DAYS'], y=peaks_flu['TOTAL_CASES'], deg=deg)
    for deg in range(2, 7)
)
# One subplot per fitted polynomial, laid out row-major on a 3x2 grid.
fig = make_subplots(rows=3, cols=2)
for idx, coeffs in enumerate([p_1, p_2, p_3, p_4, p_5, p_6]):
    fig.add_trace(
        go.Scatter(x=x, y=poly_coefficients(x, coeffs), name=f'deg={idx + 1}'),
        row=idx // 2 + 1, col=idx % 2 + 1,
    )
fig.update_layout(title_text="Polynomial plots for different degree of polynomial")
fig.show()
We conclude that none of the investigated polynomials is a good fit, so we move on to the SIR compartmental model.
def SIR(S1: int, I1: int, R1: int, weeks: int,
        b: float = 0.434, a: float = 0.4) -> tuple:
    """Discrete-time SIR compartmental model for flu prediction.

    Iterates the SIR difference equations (a forward-Euler discretisation
    of the SIR ordinary differential equations, not a PDE) with one step
    per week. The total population N = S1 + I1 + R1 is conserved.

    Parameters
    ----------
    S1 : Initial number of susceptible (uninfected) people
    I1 : Initial number of infected people
    R1 : Initial number of recovered people (uninfected and immune to flu)
    weeks : Number of weekly steps to simulate
    b : Infection (transmission) rate per week
    a : Recovery rate per week

    Returns
    -------
    (S, I, R) : three lists of length ``weeks + 1`` holding the weekly
        sizes of the susceptible, infected and recovered compartments.
    """
    S = [S1]
    I = [I1]
    R = [R1]
    N = S1 + I1 + R1
    for i in range(weeks):
        # New infections are proportional to contacts between S and I;
        # recoveries move a fraction `a` of the infected into R.
        S.append(S[i] - b * ((I[i] * S[i]) / N))
        I.append(I[i] + b * ((I[i] * S[i]) / N) - a * I[i])
        R.append(R[i] + a * I[i])
    return S, I, R
We chose the subset of the data investigated in the article and examined the initial prediction.
# Flu season examined in the article: January through mid-April 2016.
subset = flu_data[flu_data['DATE'].between('2016-01-01', '2016-04-16')]
s_1 = 20000                              # initial susceptible population guess
i_1 = subset['TOTAL_CASES'].iloc[0]      # infections in the first week
n_weeks = len(subset) - 1                # simulate across the whole subset
S, I, R = SIR(S1=s_1, I1=i_1, R1=0, weeks=n_weeks, b=0.434, a=0.4)
# Compare predicted vs observed weekly infections.
weeks_axis = list(range(n_weeks + 1))
plt.plot(I, c='r', label='pred')
plt.plot(weeks_axis, subset['TOTAL_CASES'], c='k', label='true')
plt.title('Initial prediction')
plt.xlabel('week')
plt.ylabel('infections')
plt.legend()
plt.show()
mean_squared_error(I, subset['TOTAL_CASES'])
2187746.038418362
It's clearly bad, so we decided to optimize parameters, using least squares method.
def param_optimization(initial_guesses: list) -> float:
    """Cost function: MSE between SIR prediction and observed cases.

    ``initial_guesses`` is the parameter vector ``[s_1, a, b]``.
    NOTE(review): ``least_squares`` receives this scalar as a single
    residual, so it effectively minimises the square of the MSE.
    """
    s_1, a, b = initial_guesses
    _, I_pred, _ = SIR(S1=s_1, I1=i_1, R1=0,
                       weeks=n_weeks, b=b, a=a)
    squared_errors = [
        (I_pred[t] - subset['TOTAL_CASES'].iloc[t]) ** 2
        for t in range(len(subset))
    ]
    return np.mean(squared_errors)
# Optimize (S1, a, b) starting from the article's values, constrained non-negative.
opt_result = least_squares(param_optimization, [20000, 0.4, 0.434], bounds=(0, np.inf))
s_1, a, b = opt_result.x
s_1, a, b
(19999.86756245367, 0.5857148081863623, 1.038152476470345)
# Re-run the SIR model with the optimized parameters and compare to data.
S, I, R = SIR(S1=s_1, I1=i_1, R1=0, weeks=n_weeks, b=b, a=a)
plt.plot(I, c='r', label='pred')
plt.plot(list(range(n_weeks + 1)), subset['TOTAL_CASES'], c='k', label='true')
plt.title('Prediction after optimization')
plt.xlabel('week')
plt.ylabel('infections')
plt.legend()
plt.show()
mean_squared_error(I, subset['TOTAL_CASES'])
38187.20107554164
After parameter optimization, our model was able to predict flu infections far more accurately.